{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#cifar\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import backend\n",
    "from tensorflow.keras import layers\n",
    "import os\n",
    "import numpy as np \n",
    "from matplotlib import pyplot as plt\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create CNN\n",
    "def CNNmodel(input_shape, filters=64, kernel=(3,3), size=4, dropout=0.2, num_classes=10, **kwargs):\n",
    "    \"\"\"Build a plain stacked-convolution image classifier.\n",
    "\n",
    "    Layout: two stride-2 stem convs (8 then 16 channels), then `size`\n",
    "    stride-1 conv blocks, global average pooling, dropout and a softmax\n",
    "    classifier. Every conv is bias-free and followed by BatchNormalization\n",
    "    and a ReLU capped at 6.\n",
    "\n",
    "    Args:\n",
    "        input_shape: shape of a single input image, e.g. (32, 32, 3).\n",
    "        filters: number of channels in every body conv block.\n",
    "        kernel: kernel size of every body conv block.\n",
    "        size: number of body conv blocks.\n",
    "        dropout: dropout rate applied before the classifier.\n",
    "        num_classes: number of output classes (10 for CIFAR-10).\n",
    "        **kwargs: accepted for call compatibility; ignored.\n",
    "\n",
    "    Returns:\n",
    "        An uncompiled keras.Model producing per-class probabilities.\n",
    "    \"\"\"\n",
    "    _inputs = layers.Input(shape=input_shape)\n",
    "    x = layers.Conv2D(8, (3,3), padding='same', use_bias=False, strides=(2,2), name='conv_0')(_inputs)\n",
    "    x = layers.BatchNormalization(axis=-1, name='conv_0_bn')(x)\n",
    "    x = layers.ReLU(6., name='conv_0_relu')(x)\n",
    "\n",
    "    # Chain from conv_0's output; feeding _inputs here would leave conv_0 unused.\n",
    "    x = layers.Conv2D(16, (3,3), padding='same', use_bias=False, strides=(2,2), name='conv_1')(x)\n",
    "    x = layers.BatchNormalization(axis=-1, name='conv_1_bn')(x)\n",
    "    x = layers.ReLU(6., name='conv_1_relu')(x)\n",
    "\n",
    "    # Body: `size` identical conv/BN/ReLU6 blocks named conv_2 .. conv_{size+1}.\n",
    "    for block_id in range(2, size + 2):\n",
    "        x = layers.Conv2D(filters, kernel, padding='same', use_bias=False, strides=(1,1), name='conv_%d' % block_id)(x)\n",
    "        x = layers.BatchNormalization(axis=-1, name='conv_%d_bn' % block_id)(x)\n",
    "        x = layers.ReLU(6., name='conv_%d_relu' % block_id)(x)\n",
    "\n",
    "    x = layers.GlobalAveragePooling2D()(x)\n",
    "    x = layers.Dropout(dropout, name='dropout')(x)\n",
    "    x = layers.Dense(num_classes)(x)\n",
    "    x = layers.Softmax()(x)\n",
    "    return keras.Model(inputs=_inputs, outputs=x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# preprocess input\n",
    "def preprocess_input(inputs, std=255., mean=0., expand_dims=None):\n",
    "    \"\"\"Cast to float32 and normalise as (inputs - mean) / std.\n",
    "\n",
    "    Args:\n",
    "        inputs: image tensor or array of any numeric dtype.\n",
    "        std: divisor; with the default mean=0, 255 maps [0, 255] pixels to [0, 1].\n",
    "        mean: value subtracted before dividing.\n",
    "        expand_dims: if not None, axis at which a new size-1 dimension is inserted.\n",
    "\n",
    "    Returns:\n",
    "        A float32 tf.Tensor.\n",
    "    \"\"\"\n",
    "    inputs = tf.cast(inputs, tf.float32)\n",
    "    inputs = (inputs - mean) / std\n",
    "    if expand_dims is not None:\n",
    "        # The result must be reassigned (expand_dims does not modify in place); the TF op\n",
    "        # also works on symbolic tensors inside tf.data.Dataset.map, unlike np.expand_dims.\n",
    "        inputs = tf.expand_dims(inputs, expand_dims)\n",
    "    return inputs\n",
    "\n",
    "# dataset aug\n",
    "def img_aug_fun(elem):\n",
    "    \"\"\"Randomly augment a single image, then normalise it with preprocess_input.\n",
    "\n",
    "    The random ops run in a fixed order: horizontal flip, brightness shift\n",
    "    (max_delta=0.5), contrast factor drawn from [0.5, 1.5].\n",
    "    \"\"\"\n",
    "    flipped = tf.image.random_flip_left_right(elem)  # random left-right flip\n",
    "    brightened = tf.image.random_brightness(flipped, max_delta=0.5)  # random brightness\n",
    "    contrasted = tf.image.random_contrast(brightened, lower=0.5, upper=1.5)  # random contrast\n",
    "    return preprocess_input(contrasted)\n",
    "\n",
    "# load CIFAR10 dataset, size(32,32,3)\n",
    "BATCH_SIZE = 128\n",
    "(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()\n",
    "assert x_train.shape == (50000, 32, 32, 3)\n",
    "assert x_test.shape == (10000, 32, 32, 3)\n",
    "assert y_train.shape == (50000, 1)\n",
    "assert y_test.shape == (10000, 1)\n",
    "\n",
    "# Test images are only normalised (no augmentation), eagerly and once.\n",
    "x_test = preprocess_input(x_test)\n",
    "\n",
    "# Shuffle before augmenting/batching so each epoch sees a new sample order\n",
    "# (shuffle reshuffles on every iteration by default); without it SGD would\n",
    "# walk the training set in the same fixed order every epoch.\n",
    "AUTOTUNE = tf.data.experimental.AUTOTUNE\n",
    "x_y_train_ds = (\n",
    "    tf.data.Dataset.from_tensor_slices((x_train, y_train))\n",
    "    .shuffle(buffer_size=len(x_train))\n",
    "    .map(lambda img, label: (img_aug_fun(img), label), num_parallel_calls=AUTOTUNE)\n",
    "    .batch(BATCH_SIZE)\n",
    "    .prefetch(AUTOTUNE)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training\n",
    "# NOTE(review): validation_data is the test split, so LR scheduling and early\n",
    "# stopping are tuned on test data; a held-out slice of x_train would keep the\n",
    "# reported test accuracy unbiased.\n",
    "EPOCHS = 1  # smoke-test value; the callbacks need more epochs than their patience (4 / 8) to act\n",
    "\n",
    "# Both callbacks watch validation accuracy, so LR decay and early stopping react to the same signal.\n",
    "reduce_lr = keras.callbacks.ReduceLROnPlateau(monitor='val_accuracy', factor=0.5, patience=4, min_lr=0.0001, verbose=1)\n",
    "earlystop = keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=8, verbose=1)\n",
    "\n",
    "model = CNNmodel(input_shape=(32,32,3), filters=64, kernel=(3,3), size=9)\n",
    "model.compile(optimizer='SGD', loss='sparse_categorical_crossentropy', metrics=['accuracy'])\n",
    "history = model.fit(x_y_train_ds, validation_data=(x_test, y_test), callbacks=[reduce_lr, earlystop], verbose=2, epochs=EPOCHS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "# marker='o' keeps the curves visible for single-epoch runs (a one-point line draws nothing).\n",
    "ax.plot(history.history['accuracy'], marker='o', label='train_acc')\n",
    "ax.plot(history.history['val_accuracy'], marker='o', label='val_acc')\n",
    "ax.set(title='CIFAR-10 CNN accuracy', xlabel='Epochs', ylabel='Acc')\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Saving to .h5 fails if the target directory is missing, so create it first.\n",
    "os.makedirs('checkpoint', exist_ok=True)\n",
    "model.save('checkpoint/Cifar10_CNN_%.3f.h5' % history.history['val_accuracy'][-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
